import autograd.numpy as np


class ContinuousFactor:
    """A node in a tree of continuous factors.

    Each node carries a score vector ``z`` and a loading vector ``mu``. For a
    non-NN model the root is initialised with ``z = 0.5 * ones(model.N)`` and
    ``mu = zeros(model.D)`` and those values are kept fixed (see ``set_z`` /
    ``set_mu``).

    Tree structure is exposed through the ``parent`` attribute (``None`` for
    the root; note it is an attribute, not a method) and the ``children``
    list, plus the ``left()`` / ``right()`` helpers.

    The ``model`` object must provide ``nn`` (truthy/falsy), ``N``, ``D`` and
    ``blank_factor(parent, orientation, depth, subtree=..., subtree_loading=...)``.
    """

    def __init__(self, parent, model, orientation, depth, subtree=False, subtree_loading=None):
        """Create a factor node.

        Args:
            parent: Parent ``ContinuousFactor``, or ``None`` for the root.
            model: Owning model (see class docstring for required members).
            orientation: Truthy -> this node is gated by ``parent.z``;
                falsy -> gated by ``1 - parent.z``.
            depth: Depth of this node in the tree (root is whatever the caller
                passes; children use ``depth + 1``).
            subtree: Whether this node is the head of a subtree created by
                ``create_subtrees``.
            subtree_loading: Optional array multiplied into this node's
                compound parent score when ``subtree`` is true.
        """
        self.orientation = orientation
        self.model = model
        self.parent = parent
        self.children = []
        self.fixed = False  # not read anywhere in this class; kept for external users
        self.depth = depth
        self.mu = None
        self.z = None
        if self.parent is None and not self.model.nn:
            # Non-NN root: constant half-weight scores and a zero loading
            # (the data is assumed 0-centered).
            self.z = 0.5 * np.ones(self.model.N)
            self.mu = np.zeros(self.model.D)
        self.subtree_loading = subtree_loading
        self.subtree = subtree

    def left(self):
        """Return the first child (the ``orientation=False`` child after ``split``)."""
        return self.children[0]

    def right(self):
        """Return the second child (the ``orientation=True`` child after ``split``)."""
        return self.children[1]

    def ancestors(self):
        """Return the list of ancestors, nearest first, ending at the root.

        Returns an empty list for the root.
        """
        chain = []
        node = self.parent
        while node is not None:
            chain.append(node)
            node = node.parent
        return chain

    def compound_parent_score(self):
        """Return the product of all gates on the path from the root to this node.

        At each ancestor the gate is ``ancestor.z`` if the child on the path
        has a truthy ``orientation``, else ``1 - ancestor.z``. The root
        contributes ``ones(model.N)``. If this node is a subtree head with a
        ``subtree_loading``, that loading is multiplied in as well (and so
        propagates to every descendant).
        """
        if self.parent is None:
            score = np.ones(self.model.N)
        else:
            gate = self.parent.z if self.orientation else (1 - self.parent.z)
            score = self.parent.compound_parent_score() * gate
        if self.subtree and self.subtree_loading is not None:
            score = score * self.subtree_loading
        return score

    def compound_score(self):
        """Return ``compound_parent_score() * z`` for this node."""
        return self.compound_parent_score() * self.z

    def partial(self):
        """Return the outer product of ``compound_score()`` and ``mu``.

        The result has shape ``(compound_score().size, mu.size)``. It is
        computed with ``np.dot`` on a column and a row vector so it stays
        differentiable under autograd.
        """
        compound_score = self.compound_score()
        column = compound_score.reshape((compound_score.size, 1))
        row = self.mu.reshape((1, -1))
        return np.dot(column, row)

    def split(self, num_children=2):
        """Replace ``children`` with two fresh factors and return them.

        The left child has ``orientation=False`` and the right child has
        ``orientation=True``, both at ``depth + 1``.

        Note: ``num_children`` is accepted for interface compatibility but is
        ignored; exactly two children are always created.
        """
        left = self.model.blank_factor(self, False, self.depth + 1)
        right = self.model.blank_factor(self, True, self.depth + 1)
        self.children = [left, right]
        return self.children

    def create_subtrees(self, num_subtrees, subtree_loadings=None):
        """Append ``num_subtrees`` subtree-head children and return ``children``.

        Each new child has ``orientation=True`` and ``subtree=True``. When
        ``subtree_loadings`` is given, child ``i`` receives
        ``subtree_loadings[i]``; it must therefore have at least
        ``num_subtrees`` entries.
        """
        for i in range(num_subtrees):
            if subtree_loadings is not None:
                self.children.append(self.model.blank_factor(self, True, self.depth + 1, subtree=True, subtree_loading=subtree_loadings[i]))
            else:
                self.children.append(self.model.blank_factor(self, True, self.depth + 1, subtree=True))
        return self.children

    def set_z(self, z):
        """Set the score vector ``z``.

        Ignored for the root of a non-NN model, whose ``z`` stays fixed.
        """
        if self.parent is None and not self.model.nn:
            return  # if not NN, assumed 0-centered data
        self.z = z

    def set_mu(self, mu):
        """Set the loading vector ``mu``.

        For the root of a non-NN model the argument is ignored and ``mu`` is
        reset to ``zeros(model.D)``.
        """
        if self.parent is None and not self.model.nn:
            self.mu = np.zeros(self.model.D)  # if not NN, assumed 0-centered data
            return
        self.mu = mu
